#include "src/arm_common/pooling/algo.h"
#include "megdnn/opr_param_defs.h"
#include "src/arm_common/pooling/do_max_pooling_3x3_s2x2_int8.h"
#include "src/arm_common/pooling/do_max_pooling_w2x2_s2x2.h"
#include "src/arm_common/pooling/do_max_pooling_w4x4_s2x2.h"
#include "src/arm_common/pooling/do_pooling_2x2_nchw44.h"
#include "src/arm_common/pooling/do_pooling_3x3_nchw44.h"
#include "src/arm_common/pooling/do_pooling_4x4_nchw44.h"
#include "src/arm_common/pooling/do_pooling_5x5_nchw44.h"

#include "midout.h"

MIDOUT_DECL(megdnn_arm_common_pooling)

namespace megdnn {
namespace arm_common {

//! Describe the scratch memory layout used by the NCHW stride-2 kernels with
//! a 3x3 or 5x5 window.
//!
//! The layout consists of filter[0] chunks holding one output row each,
//! followed by two chunks of (IW + 1) / 2 elements each. According to the
//! original author's note, an input row is stored split into its odd and even
//! positions before being processed into one row of results.
//!
//! The returned bundle has a null base pointer and 16-byte alignment.
WorkspaceBundle get_bundle(const PoolingImpl::PoolingKernSizeParam& param) {
    // filter[0] == filter[1] is part of the assertion, so the mixed
    // filter[0] / filter[1] size check accepts exactly 3x3 and 5x5 windows.
    megdnn_assert(
            (param.src_type.category() == DTypeCategory::FLOAT ||
             param.src_type.enumv() == DTypeEnum::QuantizedS8 ||
             param.src_type.enumv() == DTypeEnum::Quantized8Asymm ||
             param.src_type == dtype::Int8{}) &&
            param.format == param::Pooling::Format::NCHW &&
            (param.mode == param::Pooling::Mode::MAX ||
             (param.mode == param::Pooling::Mode::AVERAGE && param.filter[0] == 3)) &&
            param.filter[0] == param.filter[1] &&
            (param.filter[0] == 3 || param.filter[1] == 5) && param.stride[0] == 2 &&
            param.stride[1] == 2 && param.isz[0] >= 2 && param.isz[1] >= 2);

    const size_t elem_bytes = param.src_type.size();
    const auto in_width = param.isz[1];
    const auto out_width = param.osz[1];

    SmallVector<size_t> chunk_bytes;

    // One output-row sized chunk per filter row.
    const size_t out_row_bytes = out_width * elem_bytes;
    for (auto rows_left = param.filter[0]; rows_left != 0; --rows_left) {
        chunk_bytes.push_back(out_row_bytes);
    }

    // Two half-width chunks for the odd / even split of an input row.
    const size_t half_in_row_bytes = (in_width + 1) / 2 * elem_bytes;
    chunk_bytes.push_back(half_in_row_bytes);
    chunk_bytes.push_back(half_in_row_bytes);

    return WorkspaceBundle(nullptr, chunk_bytes, 16);
}

//! Describe the scratch memory needed by the int8 NCHW44 kernels.
//!
//! The bundle has a single chunk: when either padding value is non-zero it is
//! sized for a padded plane of (IH + 2 * PH) x (IW + 2 * PW) positions with 4
//! int8 values each, otherwise it is empty (0 bytes).
WorkspaceBundle get_bundle_nchw44(const PoolingImpl::PoolingKernSizeParam& param) {
    megdnn_assert(
            (param.src_type.enumv() == DTypeEnum::QuantizedS8 ||
             param.src_type.enumv() == DTypeEnum::Int8) &&
            (param.format == param::Pooling::Format::NCHW44));

    const auto in_h = param.isz[0];
    const auto in_w = param.isz[1];
    const auto pad_h = param.padding[0];
    const auto pad_w = param.padding[1];

    const bool has_padding = (pad_h != 0) || (pad_w != 0);
    const size_t padded_plane_bytes =
            has_padding
                    ? (in_w + 2 * pad_w) * (in_h + 2 * pad_h) * 4 * sizeof(int8_t)
                    : 0;
    return WorkspaceBundle(nullptr, {padded_plane_bytes});
}

//! Produce a (possibly padded) view of one int8 plane with 4 values per
//! spatial position.
//!
//! \param src          source plane of IH x IW positions
//! \param IH, IW       source height / width
//! \param IH2, IW2     [out] height / width of the returned plane
//! \param PH, PW       padding applied on each side vertically / horizontally
//! \param ws           workspace bundle; chunk 0 receives the padded copy
//! \param is_max_mode  selects the fill value: INT8_MIN when true, 0 otherwise
//! \return \p src itself when PH == PW == 0, else the padded copy in ws chunk 0
const int8_t* handle_padding(
        const int8_t* src, size_t IH, size_t IW, size_t& IH2, size_t& IW2, size_t PH,
        size_t PW, const WorkspaceBundle& ws, bool is_max_mode) {
    // Nothing to pad: hand the source back unchanged.
    if (PH == 0 && PW == 0) {
        IH2 = IH;
        IW2 = IW;
        return src;
    }

    IH2 = IH + 2 * PH;
    IW2 = IW + 2 * PW;

    int8_t* padded = static_cast<int8_t*>(ws.get(0));
    const int8_t fill_value = is_max_mode ? INT8_MIN : 0;

    // Fill the whole padded plane first, then copy each source row into place.
    std::memset(padded, fill_value, sizeof(int8_t) * IH2 * IW2 * 4);
    for (size_t row = 0; row < IH; ++row) {
        int8_t* dst_row = padded + (row + PH) * IW2 * 4 + PW * 4;
        const int8_t* src_row = src + row * IW * 4;
        std::memcpy(dst_row, src_row, sizeof(int8_t) * IW * 4);
    }
    return padded;
}

//! Accepts NCHW pooling with a float or quantized source type, stride 1x1, a
//! square 2x2 or 3x3 window, and MAX or AVERAGE mode.
bool PoolingImpl::AlgoFilterxModexStride1::usable(
        const PoolingKernSizeParam& param) const {
    const auto category = param.src_type.category();
    if (category != DTypeCategory::FLOAT && category != DTypeCategory::QUANTIZED) {
        return false;
    }
    if (param.format != Param::Format::NCHW) {
        return false;
    }
    if (param.stride[0] != 1 || param.stride[1] != 1) {
        return false;
    }

    const auto filter_h = param.filter[0];
    const auto filter_w = param.filter[1];
    if (filter_h != filter_w || (filter_h != 2 && filter_h != 3)) {
        return false;
    }

    return param.mode == Mode::MAX || param.mode == Mode::AVERAGE;
}

//! Run stride-1 NCHW pooling with a 2x2 or 3x3 window.
//!
//! The kernel (do_pooling_compact) is selected by source dtype, pooling mode
//! and filter size through the three local macros below, and is invoked once
//! per (batch, channel) plane through MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN
//! with N * C tasks.
//!
//! NOTE(review): when src_type matches none of the branches at the end of the
//! function, nothing is dispatched and dst is left untouched - confirm that
//! callers never reach this with such a dtype.
void PoolingImpl::AlgoFilterxModexStride1::exec(const PoolingKernParam& param) const {
    auto IH = param.isz[0], IW = param.isz[1];
    auto OH = param.osz[0], OW = param.osz[1];
    auto N = param.n, C = param.ic;
    auto PH = param.padding[0];
    auto PW = param.padding[1];
    auto FH = param.filter[0];

    auto src_ptr = param.src_ptr;
    auto dst_ptr = param.dst_ptr;

// DISPATCH_FUNC: builds the per-plane task and dispatches N * C of them.
// Each task splits its index into n = index / C and c = index % C, offsets
// the source by (n * C + c) * IH * IW elements and the destination by
// (n * C + c) * OH * OW elements, then calls
// do_pooling_compact<Pooler, NeonPooler, window>. The whole statement is
// wrapped in a midout region tagged with algorithm id 0, the dtype id, the
// pooler case numbers and the window size.
#define DISPATCH_FUNC(Pooler, NeonPooler, window, midout_type_id)                    \
    MIDOUT_BEGIN(                                                                    \
            megdnn_arm_common_pooling, midout_iv(0), midout_iv(midout_type_id),      \
            Pooler::MIDOUT_CASE_NUM, NeonPooler::MIDOUT_CASE_NUM, window) {          \
        auto run = [C, IH, IW, OH, OW, PH, PW, src_ptr, dst_ptr,                     \
                    src_dtype = param.src_type](size_t index, size_t) {              \
            size_t n = index / C;                                                    \
            size_t c = index % C;                                                    \
            do_pooling_compact<Pooler MEGDNN_COMMA NeonPooler MEGDNN_COMMA window>(  \
                    static_cast<const typename Pooler::ctype*>(src_ptr.get_ptr()) +  \
                            n * C * IH * IW + c * IH * IW,                           \
                    static_cast<typename Pooler::ctype*>(dst_ptr.get_ptr()) +        \
                            n * C * OH * OW + c * OH * OW,                           \
                    src_dtype, IH, IW, OH, OW, PH, PW);                              \
        };                                                                           \
        MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN(                                       \
                static_cast<::megdnn::naive::HandleImpl*>(param.handle), N* C, run); \
    }                                                                                \
    MIDOUT_END()

// DISPATCH_WINDOW: switches on the filter height FH. FH == 2 instantiates the
// poolers with a first template argument of 4, FH == 3 with 9 (presumably the
// number of elements in the window - confirm against the pooler definitions).
// Any other FH trips an assertion.
#define DISPATCH_WINDOW(Pooler, NeonPooler, dtype, ctype, comp_type, midout_type_id) \
    switch (FH) {                                                                    \
        case 2: {                                                                    \
            using _Pooler = Pooler<4, dtype, ctype, comp_type>;                      \
            using _NeonPooler = NeonPooler<4, dtype, ctype, comp_type>;              \
            DISPATCH_FUNC(_Pooler, _NeonPooler, 2, midout_type_id);                  \
            break;                                                                   \
        }                                                                            \
        case 3: {                                                                    \
            using _Pooler = Pooler<9, dtype, ctype, comp_type>;                      \
            using _NeonPooler = NeonPooler<9, dtype, ctype, comp_type>;              \
            DISPATCH_FUNC(_Pooler, _NeonPooler, 3, midout_type_id);                  \
            break;                                                                   \
        }                                                                            \
        default:                                                                     \
            megdnn_assert(0, "unsupport pooling filter size");                       \
            break;                                                                   \
    }

// DISPATCH_MODE: switches on param.mode. MAX uses MaxPooler / NeonMaxPooler,
// AVERAGE uses MeanInPooler / NeonMeanPooler; any other mode trips an
// assertion.
#define DISPATCH_MODE(dtype, ctype, comp_type, midout_type_id)             \
    switch (param.mode) {                                                  \
        case Mode::MAX:                                                    \
            DISPATCH_WINDOW(                                               \
                    MaxPooler, NeonMaxPooler, dtype, ctype, comp_type,     \
                    midout_type_id);                                       \
            break;                                                         \
        case Mode::AVERAGE:                                                \
            DISPATCH_WINDOW(                                               \
                    MeanInPooler, NeonMeanPooler, dtype, ctype, comp_type, \
                    midout_type_id);                                       \
            break;                                                         \
        default:                                                           \
            megdnn_assert(0, "unsupport pooling mode");                    \
            break;                                                         \
    }

    // Select the element / compute types from the source dtype. The last
    // argument is the dtype id used in the midout tag. The Float16 branch is
    // only compiled when the target supports FP16 vector arithmetic.
    if (param.src_type == dtype::Float32{}) {
        DISPATCH_MODE(dt_float32, float, float, 0);
    } else if (param.src_type.enumv() == DTypeEnum::QuantizedS8) {
        DISPATCH_MODE(dt_qint8, int8_t, float, 1);
    } else if (param.src_type.enumv() == DTypeEnum::Quantized8Asymm) {
        DISPATCH_MODE(dt_quint8, uint8_t, float, 2);
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
    } else if (param.src_type == dtype::Float16{}) {
        DISPATCH_MODE(dt_float16, __fp16, __fp16, 3);
#endif
    }
// Drop the helper macros so they do not leak into the following functions.
#undef DISPATCH_FUNC
#undef DISPATCH_WINDOW
#undef DISPATCH_MODE
}
//! Accepts NCHW pooling with a float or quantized source type, a 2x2 window,
//! stride 2x2, and MAX or AVERAGE mode.
bool PoolingImpl::AlgoFilter2ModexStride2::usable(
        const PoolingKernSizeParam& param) const {
    const auto category = param.src_type.category();
    const bool dtype_ok = category == DTypeCategory::FLOAT ||
                          category == DTypeCategory::QUANTIZED;
    const bool layout_ok = param.format == Param::Format::NCHW;

    const auto filter_h = param.filter[0];
    const auto filter_w = param.filter[1];
    const auto stride_h = param.stride[0];
    const auto stride_w = param.stride[1];
    const bool window_ok = filter_h == filter_w && filter_h == 2;
    const bool stride_ok = stride_h == stride_w && stride_h == 2;

    const bool mode_ok = param.mode == Mode::MAX || param.mode == Mode::AVERAGE;

    return dtype_ok && layout_ok && window_ok && stride_ok && mode_ok;
}

//! Run 2x2 stride-2 NCHW pooling.
//!
//! The kernel (do_pooling_2x2) is selected by source dtype and pooling mode
//! through the local macros below and is invoked once per (batch, channel)
//! plane through MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN with N * C tasks.
//!
//! NOTE(review): when src_type matches none of the branches at the end of the
//! function, nothing is dispatched and dst is left untouched - confirm that
//! callers never reach this with such a dtype.
void PoolingImpl::AlgoFilter2ModexStride2::exec(const PoolingKernParam& param) const {
    auto IH = param.isz[0], IW = param.isz[1];
    auto OH = param.osz[0], OW = param.osz[1];
    auto N = param.n, C = param.ic;
    auto PH = param.padding[0];
    auto PW = param.padding[1];

    auto src_ptr = param.src_ptr;
    auto dst_ptr = param.dst_ptr;
// DISPATCH_FUNC: builds the per-plane task and dispatches N * C of them.
// Each task splits its index into n = index / C and c = index % C, offsets
// the source by (n * C + c) * IH * IW elements and the destination by
// (n * C + c) * OH * OW elements, then calls do_pooling_2x2<Pooler, mode>.
// The statement is wrapped in a midout region tagged with algorithm id 1,
// the dtype id and the pooler case number.
#define DISPATCH_FUNC(Pooler, mode, midout_type_id)                                  \
    MIDOUT_BEGIN(                                                                    \
            megdnn_arm_common_pooling, midout_iv(1), midout_iv(midout_type_id),      \
            Pooler::MIDOUT_CASE_NUM) {                                               \
        auto run = [C, IH, IW, OH, OW, PH, PW, src_ptr, dst_ptr,                     \
                    src_dtype = param.src_type](size_t index, size_t) {              \
            size_t n = index / C;                                                    \
            size_t c = index % C;                                                    \
            do_pooling_2x2<Pooler MEGDNN_COMMA mode>(                                \
                    static_cast<const typename Pooler::ctype*>(src_ptr.get_ptr()) +  \
                            n * C * IH * IW + c * IH * IW,                           \
                    static_cast<typename Pooler::ctype*>(dst_ptr.get_ptr()) +        \
                            n * C * OH * OW + c * OH * OW,                           \
                    src_dtype, IH, IW, OH, OW, PH, PW);                              \
        };                                                                           \
        MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN(                                       \
                static_cast<::megdnn::naive::HandleImpl*>(param.handle), N* C, run); \
    }                                                                                \
    MIDOUT_END()

// DISPATCH_MODE: switches on param.mode. MAX uses MaxPooler and AVERAGE uses
// MeanInPooler, both instantiated with a first template argument of 4; any
// other mode trips an assertion.
#define DISPATCH_MODE(dtype, ctype, comp_type, midout_type_id)        \
    switch (param.mode) {                                             \
        case Mode::MAX: {                                             \
            using _Pooler = MaxPooler<4, dtype, ctype, comp_type>;    \
            DISPATCH_FUNC(_Pooler, Mode::MAX, midout_type_id);        \
            break;                                                    \
        }                                                             \
        case Mode::AVERAGE: {                                         \
            using _Pooler = MeanInPooler<4, dtype, ctype, comp_type>; \
            DISPATCH_FUNC(_Pooler, Mode::AVERAGE, midout_type_id);    \
            break;                                                    \
        }                                                             \
        default:                                                      \
            megdnn_assert(0, "unsupport pooling mode");               \
            break;                                                    \
    }

    // Select the element / compute types from the source dtype. The last
    // argument is the dtype id used in the midout tag. The Float16 branch is
    // only compiled when the target supports FP16 vector arithmetic.
    if (param.src_type == dtype::Float32{}) {
        DISPATCH_MODE(dt_float32, float, float, 0);
    } else if (param.src_type.enumv() == DTypeEnum::QuantizedS8) {
        DISPATCH_MODE(dt_qint8, int8_t, float, 1);
    } else if (param.src_type.enumv() == DTypeEnum::Quantized8Asymm) {
        DISPATCH_MODE(dt_quint8, uint8_t, float, 2);
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
    } else if (param.src_type == dtype::Float16{}) {
        DISPATCH_MODE(dt_float16, __fp16, __fp16, 3);
#endif
    }
// Drop the helper macros so they do not leak into the following functions.
#undef DISPATCH_FUNC
// NOTE(review): DISPATCH_PAD is not defined anywhere in this function; this
// #undef looks like a leftover and has no effect here.
#undef DISPATCH_PAD
#undef DISPATCH_MODE
}

//! Accepts NCHW MAX pooling with a float or quantized source type, a 3x3
//! window, stride 2x2, and an input of at least 2x2.
bool PoolingImpl::AlgoFilter3MaxStride2::usable(
        const PoolingKernSizeParam& param) const {
    const auto category = param.src_type.category();
    const bool dtype_ok = category == DTypeCategory::FLOAT ||
                          category == DTypeCategory::QUANTIZED;
    const bool is_nchw_max =
            param.format == Param::Format::NCHW && param.mode == Mode::MAX;
    const bool window_ok = param.filter[0] == 3 && param.filter[1] == 3;
    const bool stride_ok = param.stride[0] == 2 && param.stride[1] == 2;
    const bool input_ok = param.isz[0] >= 2 && param.isz[1] >= 2;

    return dtype_ok && is_nchw_max && window_ok && stride_ok && input_ok;
}

//! Run 3x3 stride-2 NCHW max pooling.
//!
//! One do_max_pooling_3x3_s2x2_<func>_NEON call is issued per (batch, channel)
//! plane through MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN with N * C tasks. Each
//! task works on its own slice of the workspace, selected by thread_id.
//!
//! NOTE(review): when src_type matches none of the branches at the end of the
//! function, nothing is dispatched and dst is left untouched - confirm that
//! callers never reach this with such a dtype.
void PoolingImpl::AlgoFilter3MaxStride2::exec(const PoolingKernParam& param) const {
    auto IH = param.isz[0], IW = param.isz[1];
    auto OH = param.osz[0], OW = param.osz[1];
    auto N = param.n, C = param.ic;
    auto PH = param.padding[0];
    auto PW = param.padding[1];

    auto src_ptr = param.src_ptr;
    auto dst_ptr = param.dst_ptr;

// DISPATCH_FUNC(type, func, midout_type_id):
//   type           - element type used to cast the src / dst pointers
//   func           - pasted into the kernel name do_max_pooling_3x3_s2x2_<func>_NEON
//   midout_type_id - dtype id used in the midout tag (algorithm id is 2)
// The task copies the bundle layout from get_bundle(param) and points the
// copy at workspace_ptr + total_size_in_bytes() * thread_id, so every thread
// id uses a distinct workspace slice. The index is split into n = index / C
// and c = index % C to locate the source and destination planes.
#define DISPATCH_FUNC(type, func, midout_type_id)                                      \
    MIDOUT_BEGIN(megdnn_arm_common_pooling, midout_iv(2), midout_iv(midout_type_id)) { \
        WorkspaceBundle wbundle = get_bundle(param);                                   \
        auto run = [C, IH, IW, OH, OW, PH, PW, src_ptr, dst_ptr, wbundle = wbundle,    \
                    workspace_ptr = param.workspace<dt_byte>()](                       \
                           size_t index, size_t thread_id) {                           \
            auto ws = wbundle;                                                         \
            ws.set(workspace_ptr + ws.total_size_in_bytes() * thread_id);              \
            size_t n = index / C;                                                      \
            size_t c = index % C;                                                      \
            do_max_pooling_3x3_s2x2_##func##_NEON(                                     \
                    static_cast<const type*>(src_ptr.get_ptr()) + n * C * IH * IW +    \
                            c * IH * IW,                                               \
                    static_cast<type*>(dst_ptr.get_ptr()) + n * C * OH * OW +          \
                            c * OH * OW,                                               \
                    IH, IW, OH, OW, PH, PW, ws);                                       \
        };                                                                             \
        MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN(                                         \
                static_cast<::megdnn::naive::HandleImpl*>(param.handle), N* C, run);   \
    }                                                                                  \
    MIDOUT_END();

    // Select the kernel from the source dtype. The Float16 branch is only
    // compiled when the target supports FP16 vector arithmetic.
    if (param.src_type == dtype::Float32{}) {
        DISPATCH_FUNC(float, float, 0);
    } else if (param.src_type.enumv() == DTypeEnum::QuantizedS8) {
        DISPATCH_FUNC(int8_t, int8, 1);
    } else if (param.src_type.enumv() == DTypeEnum::Quantized8Asymm) {
        DISPATCH_FUNC(uint8_t, uint8, 2);
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
    } else if (param.src_type == dtype::Float16{}) {
        DISPATCH_FUNC(__fp16, float16, 3);
#endif
    }
// Drop the helper macro so it does not leak into the following functions.
#undef DISPATCH_FUNC
}
//! Accepts NCHW AVERAGE pooling with a float source type, a 3x3 window,
//! stride 2x2, and an input of at least 2x2.
bool PoolingImpl::AlgoFilter3AverageStride2::usable(
        const PoolingKernSizeParam& param) const {
    if (param.src_type.category() != DTypeCategory::FLOAT) {
        return false;
    }
    if (param.format != Param::Format::NCHW || param.mode != Mode::AVERAGE) {
        return false;
    }
    if (param.filter[0] != 3 || param.filter[1] != 3) {
        return false;
    }
    if (param.stride[0] != 2 || param.stride[1] != 2) {
        return false;
    }
    return param.isz[0] >= 2 && param.isz[1] >= 2;
}

//! Run 3x3 stride-2 NCHW average pooling.
//!
//! One do_average_pooling_3x3_s2x2_NEON call is issued per (batch, channel)
//! plane through MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN with N * C tasks. Each
//! task works on its own slice of the workspace, selected by thread_id.
//!
//! NOTE(review): when src_type is neither Float32 nor (where compiled in)
//! Float16, nothing is dispatched and dst is left untouched - confirm that
//! callers never reach this with such a dtype.
void PoolingImpl::AlgoFilter3AverageStride2::exec(const PoolingKernParam& param) const {
    auto IH = param.isz[0], IW = param.isz[1];
    auto OH = param.osz[0], OW = param.osz[1];
    auto N = param.n, C = param.ic;
    auto PH = param.padding[0];
    auto PW = param.padding[1];

    auto src_ptr = param.src_ptr;
    auto dst_ptr = param.dst_ptr;

// DISPATCH_FUNC(type, MEGDNN_SIMD_WIDTH, midout_type_id):
//   type              - element type used to cast the src / dst pointers
//   MEGDNN_SIMD_WIDTH - forwarded as the last kernel argument
//   midout_type_id    - dtype id used in the midout tag (algorithm id is 3)
// The task copies the bundle layout from get_bundle(param) and points the
// copy at workspace_ptr + total_size_in_bytes() * thread_id, so every thread
// id uses a distinct workspace slice. The index is split into n = index / C
// and c = index % C to locate the source and destination planes.
#define DISPATCH_FUNC(type, MEGDNN_SIMD_WIDTH, midout_type_id)                         \
    MIDOUT_BEGIN(megdnn_arm_common_pooling, midout_iv(3), midout_iv(midout_type_id)) { \
        WorkspaceBundle wbundle = get_bundle(param);                                   \
        auto run = [C, IH, IW, OH, OW, PH, PW, src_ptr, dst_ptr, wbundle = wbundle,    \
                    workspace_ptr = param.workspace<dt_byte>()](                       \
                           size_t index, size_t thread_id) {                           \
            auto ws = wbundle;                                                         \
            ws.set(workspace_ptr + ws.total_size_in_bytes() * thread_id);              \
            size_t n = index / C;                                                      \
            size_t c = index % C;                                                      \
            do_average_pooling_3x3_s2x2_NEON(                                          \
                    static_cast<const type*>(src_ptr.get_ptr()) + n * C * IH * IW +    \
                            c * IH * IW,                                               \
                    static_cast<type*>(dst_ptr.get_ptr()) + n * C * OH * OW +          \
                            c * OH * OW,                                               \
                    IH, IW, OH, OW, PH, PW, ws, MEGDNN_SIMD_WIDTH);                    \
        };                                                                             \
        MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN(                                         \
                static_cast<::megdnn::naive::HandleImpl*>(param.handle), N* C, run);   \
    }                                                                                  \
    MIDOUT_END();
    // Float32 passes a width of 4 and Float16 a width of 8. The Float16
    // branch is only compiled when the target supports FP16 vector arithmetic.
    if (param.src_type == dtype::Float32{}) {
        DISPATCH_FUNC(dt_float32, 4, 0);
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
    } else if (param.src_type == dtype::Float16{}) {
        DISPATCH_FUNC(__fp16, 8, 1);
#endif
    }
// Drop the helper macro so it does not leak into the following functions.
#undef DISPATCH_FUNC
}
//! Accepts NCHW MAX pooling with a float or quantized source type, a 4x4
//! window, stride 2x2, and an output of at least 2x2.
bool PoolingImpl::AlgoFilter4MaxStride2::usable(
        const PoolingKernSizeParam& param) const {
    const auto category = param.src_type.category();
    if (category != DTypeCategory::FLOAT && category != DTypeCategory::QUANTIZED) {
        return false;
    }
    if (param.format != Param::Format::NCHW || param.mode != Mode::MAX) {
        return false;
    }

    const bool window_is_4x4 = param.filter[0] == 4 && param.filter[1] == 4;
    const bool stride_is_2x2 = param.stride[0] == 2 && param.stride[1] == 2;
    const bool output_large_enough = param.osz[0] >= 2 && param.osz[1] >= 2;
    return window_is_4x4 && stride_is_2x2 && output_large_enough;
}

//! Run 4x4 stride-2 NCHW max pooling.
//!
//! One do_max_pooling_w4x4_s2x2_<func>_NEON call is issued per
//! (batch, channel) plane through MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN with
//! N * C tasks. No workspace is used by this function.
//!
//! NOTE(review): when src_type matches none of the branches at the end of the
//! function, nothing is dispatched and dst is left untouched - confirm that
//! callers never reach this with such a dtype.
void PoolingImpl::AlgoFilter4MaxStride2::exec(const PoolingKernParam& param) const {
    auto IH = param.isz[0], IW = param.isz[1];
    auto OH = param.osz[0], OW = param.osz[1];
    auto N = param.n, C = param.ic;
    auto PH = param.padding[0];
    auto PW = param.padding[1];

    auto src_ptr = param.src_ptr;
    auto dst_ptr = param.dst_ptr;

// DISPATCH_FUNC(type, func, midout_type_id):
//   type           - element type used to cast the src / dst pointers
//   func           - pasted into the kernel name do_max_pooling_w4x4_s2x2_<func>_NEON
//   midout_type_id - dtype id used in the midout tag (algorithm id is 4)
// The index is split into n = index / C and c = index % C to locate the
// source and destination planes; the source dtype is forwarded to the kernel.
#define DISPATCH_FUNC(type, func, midout_type_id)                                      \
    MIDOUT_BEGIN(megdnn_arm_common_pooling, midout_iv(4), midout_iv(midout_type_id)) { \
        auto run = [C, IH, IW, OH, OW, PH, PW, src_ptr, dst_ptr,                       \
                    src_dtype = param.src_type](size_t index, size_t) {                \
            size_t n = index / C;                                                      \
            size_t c = index % C;                                                      \
            do_max_pooling_w4x4_s2x2_##func##_NEON(                                    \
                    static_cast<const type*>(src_ptr.get_ptr()) + n * C * IH * IW +    \
                            c * IH * IW,                                               \
                    static_cast<type*>(dst_ptr.get_ptr()) + n * C * OH * OW +          \
                            c * OH * OW,                                               \
                    src_dtype, IH, IW, OH, OW, PH, PW);                                \
        };                                                                             \
        MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN(                                         \
                static_cast<::megdnn::naive::HandleImpl*>(param.handle), N* C, run);   \
    }                                                                                  \
    MIDOUT_END();

    // Select the kernel from the source dtype. The Float16 branch is only
    // compiled when the target supports FP16 vector arithmetic.
    if (param.src_type == dtype::Float32{}) {
        DISPATCH_FUNC(float, float, 0);
    } else if (param.src_type.enumv() == DTypeEnum::QuantizedS8) {
        DISPATCH_FUNC(int8_t, int8, 1);
    } else if (param.src_type.enumv() == DTypeEnum::Quantized8Asymm) {
        DISPATCH_FUNC(uint8_t, uint8, 2);
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
    } else if (param.src_type == dtype::Float16{}) {
        DISPATCH_FUNC(__fp16, float16, 3);
#endif
    }
// Drop the helper macro so it does not leak into the following functions.
#undef DISPATCH_FUNC
}
//! Accepts NCHW MAX pooling with a float or quantized source type, a 5x5
//! window, stride 2x2, and an output of at least 2x2.
bool PoolingImpl::AlgoFilter5MaxStride2::usable(
        const PoolingKernSizeParam& param) const {
    const auto category = param.src_type.category();
    const bool dtype_ok = category == DTypeCategory::FLOAT ||
                          category == DTypeCategory::QUANTIZED;
    const bool is_nchw_max =
            param.format == Param::Format::NCHW && param.mode == Mode::MAX;
    const bool window_is_5x5 = param.filter[0] == 5 && param.filter[1] == 5;
    const bool stride_is_2x2 = param.stride[0] == 2 && param.stride[1] == 2;
    const bool output_large_enough = param.osz[0] >= 2 && param.osz[1] >= 2;

    return dtype_ok && is_nchw_max && window_is_5x5 && stride_is_2x2 &&
           output_large_enough;
}

//! Run 5x5 stride-2 NCHW max pooling.
//!
//! One do_max_pooling_w5x5_s2x2_NEON<dtype> call is issued per
//! (batch, channel) plane through MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN with
//! N * C tasks. Each task works on its own slice of the workspace, selected
//! by thread_id.
//!
//! NOTE(review): when src_type matches none of the branches at the end of the
//! function, nothing is dispatched and dst is left untouched - confirm that
//! callers never reach this with such a dtype.
void PoolingImpl::AlgoFilter5MaxStride2::exec(const PoolingKernParam& param) const {
    auto IH = param.isz[0], IW = param.isz[1];
    auto OH = param.osz[0], OW = param.osz[1];
    auto N = param.n, C = param.ic;
    auto PH = param.padding[0];
    auto PW = param.padding[1];

    auto src_ptr = param.src_ptr;
    auto dst_ptr = param.dst_ptr;

// DISPATCH_FUNC(dtype, type, midout_type_id, MEGDNN_SIMD_WIDTH):
//   dtype             - template argument of do_max_pooling_w5x5_s2x2_NEON
//   type              - element type used to cast the src / dst pointers
//   midout_type_id    - dtype id used in the midout tag (algorithm id is 5)
//   MEGDNN_SIMD_WIDTH - forwarded as the last kernel argument
// The task copies the bundle layout from get_bundle(param) and points the
// copy at workspace_ptr + total_size_in_bytes() * thread_id, so every thread
// id uses a distinct workspace slice. The index is split into n = index / C
// and c = index % C to locate the source and destination planes.
#define DISPATCH_FUNC(dtype, type, midout_type_id, MEGDNN_SIMD_WIDTH)                  \
    MIDOUT_BEGIN(megdnn_arm_common_pooling, midout_iv(5), midout_iv(midout_type_id)) { \
        WorkspaceBundle wbundle = get_bundle(param);                                   \
        auto run = [C, IH, IW, OH, OW, PH, PW, src_ptr, dst_ptr, wbundle = wbundle,    \
                    workspace_ptr = param.workspace<dt_byte>()](                       \
                           size_t index, size_t thread_id) {                           \
            auto ws = wbundle;                                                         \
            ws.set(workspace_ptr + ws.total_size_in_bytes() * thread_id);              \
            size_t n = index / C;                                                      \
            size_t c = index % C;                                                      \
            do_max_pooling_w5x5_s2x2_NEON<dtype>(                                      \
                    static_cast<const type*>(src_ptr.get_ptr()) + n * C * IH * IW +    \
                            c * IH * IW,                                               \
                    static_cast<type*>(dst_ptr.get_ptr()) + n * C * OH * OW +          \
                            c * OH * OW,                                               \
                    IH, IW, OH, OW, PH, PW, ws, MEGDNN_SIMD_WIDTH);                    \
        };                                                                             \
        MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN(                                         \
                static_cast<::megdnn::naive::HandleImpl*>(param.handle), N* C, run);   \
    }                                                                                  \
    MIDOUT_END();

    // Select the kernel from the source dtype; the last argument is the width
    // passed to the kernel (4 for Float32, 16 for the 8-bit types, 8 for
    // Float16). The quantized 8-bit types are dispatched with the plain
    // dt_int8 / dt_uint8 template arguments. The Float16 branch is only
    // compiled when the target supports FP16 vector arithmetic.
    if (param.src_type == dtype::Float32{}) {
        DISPATCH_FUNC(dt_float32, float, 0, 4);
    } else if (param.src_type.enumv() == DTypeEnum::QuantizedS8) {
        DISPATCH_FUNC(dt_int8, int8_t, 1, 16);
    } else if (param.src_type.enumv() == DTypeEnum::Quantized8Asymm) {
        DISPATCH_FUNC(dt_uint8, uint8_t, 2, 16);
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
    } else if (param.src_type == dtype::Float16{}) {
        DISPATCH_FUNC(dt_float16, __fp16, 3, 8);
#endif
    }
// Drop the helper macro so it does not leak into the following functions.
#undef DISPATCH_FUNC
}

//! Accepts NCHW MAX pooling on plain Int8 data with a 2x2 window, stride 2x2
//! and no padding.
bool PoolingImpl::AlgoInt8Filter2MaxStride2::usable(
        const PoolingKernSizeParam& param) const {
    if (!(param.src_type == dtype::Int8())) {
        return false;
    }
    if (param.format != Param::Format::NCHW || param.mode != Mode::MAX) {
        return false;
    }

    const bool stride_is_2x2 = param.stride[0] == 2 && param.stride[1] == 2;
    const bool no_padding = param.padding[0] == 0 && param.padding[1] == 0;
    const bool window_is_2x2 = param.filter[0] == 2 && param.filter[1] == 2;
    return stride_is_2x2 && no_padding && window_is_2x2;
}

void PoolingImpl::AlgoInt8Filter2MaxStride2::exec(const PoolingKernParam& param) const {
    auto IH = param.isz[0], IW = param.isz[1];
    auto OH = param.osz[0], OW = param.osz[1];
    auto N = param.n, C = param.ic;

    auto src_ptr = param.src<dt_int8>();
    auto dst_ptr = param.dst<dt_int8>();

    MIDOUT_BEGIN(megdnn_arm_common_pooling, midout_iv(6)) {
        auto run = [C, IH, IW, OH, OW, src_ptr, dst_ptr](size_t index, size_t) {
            size_t n = index / C;
            size_t c = index % C;
            pooling_max_w2x2_s2x2(
                    src_ptr + n * C * IH * IW + c * IH * IW,
                    dst_ptr + n * C * OH * OW + c * OH * OW, 1, 1, IH, IW, OH, OW);
        };
        MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN(
                static_cast<::megdnn::naive::HandleImpl*>(param.handle), N * C, run);
    }
    MIDOUT_END();
}

//! Accepts NCHW MAX pooling on plain Int8 data with a 3x3 window, stride 2x2
//! and an input of at least 2x2.
bool PoolingImpl::AlgoInt8Filter3MaxStride2::usable(
        const PoolingKernSizeParam& param) const {
    const bool dtype_ok = param.src_type == dtype::Int8();
    const bool is_nchw_max =
            param.format == Param::Format::NCHW && param.mode == Mode::MAX;
    const bool window_is_3x3 = param.filter[0] == 3 && param.filter[1] == 3;
    const bool stride_is_2x2 = param.stride[0] == 2 && param.stride[1] == 2;
    const bool input_large_enough = param.isz[0] >= 2 && param.isz[1] >= 2;

    return dtype_ok && is_nchw_max && window_is_3x3 && stride_is_2x2 &&
           input_large_enough;
}

void PoolingImpl::AlgoInt8Filter3MaxStride2::exec(const PoolingKernParam& param) const {
    auto IH = param.isz[0], IW = param.isz[1];
    auto OH = param.osz[0], OW = param.osz[1];
    auto N = param.n, C = param.ic;
    auto PH = param.padding[0];
    auto PW = param.padding[1];

    auto src_ptr = param.src<dt_int8>();
    auto dst_ptr = param.dst<dt_int8>();

    MIDOUT_BEGIN(megdnn_arm_common_pooling, midout_iv(7)) {
        WorkspaceBundle wbundle = get_bundle(param);
        auto run = [C, IH, IW, OH, OW, PH, PW, src_ptr, dst_ptr, wbundle = wbundle,
                    workspace_ptr = param.workspace<dt_byte>()](
                           size_t index, size_t thread_id) {
            auto ws = wbundle;
            ws.set(workspace_ptr + thread_id * ws.total_size_in_bytes());
            size_t n = index / C;
            size_t c = index % C;
            do_max_pooling_3x3_s2x2_int8_NEON(
                    src_ptr + n * C * IH * IW + c * IH * IW,
                    dst_ptr + n * C * OH * OW + c * OH * OW, IH, IW, OH, OW, PH, PW,
                    ws);
        };
        MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN(
                static_cast<::megdnn::naive::HandleImpl*>(param.handle), N * C, run);
    }
    MIDOUT_END();
}

//! Accepts NCHW44 pooling on QuantizedS8 or Int8 data with a 3x3 window and
//! equal strides of 1 or 2, in MAX or AVERAGE mode - except that AVERAGE is
//! rejected for plain Int8.
bool PoolingImpl::AlgoFilter3ModexStridexNCHW44::usable(
        const PoolingKernSizeParam& param) const {
    const auto dtype_enum = param.src_type.enumv();
    const bool is_qint8 = dtype_enum == DTypeEnum::QuantizedS8;
    const bool is_int8 = dtype_enum == DTypeEnum::Int8;
    if (!is_qint8 && !is_int8) {
        return false;
    }
    if (param.format != Param::Format::NCHW44) {
        return false;
    }

    const bool is_max = param.mode == Mode::MAX;
    const bool is_average = param.mode == Mode::AVERAGE;
    if (!is_max && !is_average) {
        return false;
    }

    if (param.filter[0] != 3 || param.filter[1] != 3) {
        return false;
    }

    const auto stride_h = param.stride[0];
    const auto stride_w = param.stride[1];
    if (stride_w != stride_h || (stride_h != 1 && stride_w != 2)) {
        return false;
    }

    //! Int8 does not support average, because its rounding mode is different
    //! from qint8
    return !(is_int8 && is_average);
}

void PoolingImpl::AlgoFilter3ModexStridexNCHW44::exec(
        const PoolingKernParam& param) const {
    auto IH = param.isz[0], IW = param.isz[1];
    auto OH = param.osz[0], OW = param.osz[1];
    auto N = param.n, C = param.ic;
    auto PH = param.padding[0];
    auto PW = param.padding[1];
    auto SW = param.stride[0];

    auto src_ptr = param.src_ptr;
    auto dst_ptr = param.dst_ptr;

#define DISPATCH_FUNC(type, func, i, mode)                                           \
    MIDOUT_BEGIN(                                                                    \
            megdnn_arm_common_pooling, midout_iv(8), midout_iv(#type #i##_hash)) {   \
        WorkspaceBundle wbundle = get_bundle_nchw44(param);                          \
        auto run = [C, IH, IW, OH, OW, PH, PW, src_ptr, dst_ptr, wbundle = wbundle,  \
                    workspace_ptr = param.workspace<dt_byte>()](                     \
                           size_t index, size_t thread_id) {                         \
            auto ws = wbundle;                                                       \
            ws.set(workspace_ptr + ws.total_size_in_bytes() * thread_id);            \
            size_t n = index / C;                                                    \
            size_t c = index % C;                                                    \
            do_##mode##_pooling_3x3_stride##i##_##func##_nchw44_NEON(                \
                    static_cast<const type*>(src_ptr.get_ptr()) +                    \
                            n * C * IH * IW * 4 + c * IH * IW * 4,                   \
                    static_cast<type*>(dst_ptr.get_ptr()) + n * C * OH * OW * 4 +    \
                            c * OH * OW * 4,                                         \
                    IH, IW, OH, OW, PH, PW, ws);                                     \
        };                                                                           \
        MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN(                                       \
                static_cast<::megdnn::naive::HandleImpl*>(param.handle), N* C, run); \
    }                                                                                \
    MIDOUT_END();

#define DISPATCH_MODE(type, func, stride)                                              \
    switch (param.mode) {                                                              \
        case Mode::MAX: {                                                              \
            DISPATCH_FUNC(type, func, stride, max);                                    \
            break;                                                                     \
        }                                                                              \
        case Mode::AVERAGE: {                                                          \
            DISPATCH_FUNC(type, func, stride, avg);                                    \
            break;                                                                     \
        }                                                                              \
        default:                                                                       \
            megdnn_throw(                                                              \
                    ssprintf(                                                          \
                            "Unsupport pooling mode %d", static_cast<int>(param.mode)) \
                            .c_str());                                                 \
    }

#define DISPATCH_STRIDE(type, func)                                         \
    switch (SW) {                                                           \
        case 1: {                                                           \
            DISPATCH_MODE(type, func, 1);                                   \
            break;                                                          \
        }                                                                   \
        case 2: {                                                           \
            DISPATCH_MODE(type, func, 2);                                   \
            break;                                                          \
        }                                                                   \
        default:                                                            \
            megdnn_throw(ssprintf("Unsupport stride size %d", SW).c_str()); \
    }

    DISPATCH_STRIDE(int8_t, int8);

#undef DISPATCH_STRIDE
#undef DISPATCH_MODE
#undef DISPATCH_FUNC
}

//! Returns true when \p param describes a 2x2 NCHW44 pooling this algorithm
//! accepts: QuantizedS8 or Int8 source, MAX or AVERAGE mode, and equal strides
//! of 1 or 2 in both directions.
bool PoolingImpl::AlgoFilter2ModexStridexNCHW44::usable(
        const PoolingKernSizeParam& param) const {
    const auto src_enum = param.src_type.enumv();
    const bool src_is_qint8 = src_enum == DTypeEnum::QuantizedS8;
    const bool src_is_int8 = src_enum == DTypeEnum::Int8;
    if (!src_is_qint8 && !src_is_int8) {
        return false;
    }
    if (param.format != Param::Format::NCHW44) {
        return false;
    }

    const bool want_max = param.mode == Mode::MAX;
    const bool want_avg = param.mode == Mode::AVERAGE;
    if (!want_max && !want_avg) {
        return false;
    }
    //! Int8 does not support average, because its rounding mode differs from
    //! that of qint8
    if (src_is_int8 && want_avg) {
        return false;
    }

    const auto stride_h = param.stride[0];
    const auto stride_w = param.stride[1];
    const bool filter_ok = param.filter[0] == 2 && param.filter[1] == 2;
    const bool stride_ok =
            stride_h == stride_w && (stride_w == 1 || stride_w == 2);
    return filter_ok && stride_ok;
}

//! Runs 2x2 pooling over an NCHW44 int8 tensor.
//!
//! The stride (1 or 2) is checked first, then the mode (MAX or AVERAGE); the
//! pair selects one of the do_{max,avg}_pooling_2x2_stride{1,2}_int8_nchw44_NEON
//! kernels. Any other stride or mode is reported through megdnn_throw.
//! The stride is taken from param.stride[0] only.
void PoolingImpl::AlgoFilter2ModexStridexNCHW44::exec(
        const PoolingKernParam& param) const {
    auto N = param.n, C = param.ic;
    auto IH = param.isz[0], IW = param.isz[1];
    auto OH = param.osz[0], OW = param.osz[1];
    auto PH = param.padding[0], PW = param.padding[1];
    auto SW = param.stride[0];

    auto src_ptr = param.src_ptr;
    auto dst_ptr = param.dst_ptr;

//! Dispatches N * C tasks for one (stride, mode) kernel. A task maps its index
//! to (batch, chan) = (index / C, index % C), offsets src/dst by the matching
//! number of IH * IW * 4 (resp. OH * OW * 4) element planes, and binds the
//! workspace bundle to the slice selected by thread_id.
#define LAUNCH_KERN(elem_t, tag, step, pool)                                         \
    MIDOUT_BEGIN(                                                                    \
            megdnn_arm_common_pooling, midout_iv(9), midout_iv(#tag #step##_hash)) { \
        WorkspaceBundle wbundle = get_bundle_nchw44(param);                          \
        auto run = [C, IH, IW, OH, OW, PH, PW, src_ptr, dst_ptr, wbundle = wbundle,  \
                    workspace_ptr = param.workspace<dt_byte>()](                     \
                           size_t index, size_t thread_id) {                         \
            auto ws = wbundle;                                                       \
            ws.set(workspace_ptr + ws.total_size_in_bytes() * thread_id);            \
            const size_t batch = index / C;                                          \
            const size_t chan = index % C;                                           \
            const size_t src_off = batch * C * IH * IW * 4 + chan * IH * IW * 4;     \
            const size_t dst_off = batch * C * OH * OW * 4 + chan * OH * OW * 4;     \
            do_##pool##_pooling_2x2_stride##step##_##tag##_nchw44_NEON(              \
                    static_cast<const elem_t*>(src_ptr.get_ptr()) + src_off,         \
                    static_cast<elem_t*>(dst_ptr.get_ptr()) + dst_off, IH, IW, OH,   \
                    OW, PH, PW, ws);                                                 \
        };                                                                           \
        MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN(                                       \
                static_cast<::megdnn::naive::HandleImpl*>(param.handle), N* C, run); \
    }                                                                                \
    MIDOUT_END();

//! Picks the max or avg kernel for a fixed stride; other modes throw.
#define DISPATCH_BY_MODE(elem_t, tag, step)                                          \
    if (param.mode == Mode::MAX) {                                                   \
        LAUNCH_KERN(elem_t, tag, step, max);                                         \
    } else if (param.mode == Mode::AVERAGE) {                                        \
        LAUNCH_KERN(elem_t, tag, step, avg);                                         \
    } else {                                                                         \
        megdnn_throw(                                                                \
                ssprintf("Unsupport pooling mode %d", static_cast<int>(param.mode))  \
                        .c_str());                                                   \
    }

    if (SW == 1) {
        DISPATCH_BY_MODE(int8_t, int8, 1);
    } else if (SW == 2) {
        DISPATCH_BY_MODE(int8_t, int8, 2);
    } else {
        megdnn_throw(ssprintf("Unsupport stride size %d", SW).c_str());
    }

#undef DISPATCH_BY_MODE
#undef LAUNCH_KERN
}

//! Returns true when \p param describes a 4x4 NCHW44 pooling this algorithm
//! accepts: QuantizedS8 or Int8 source, MAX or AVERAGE mode, and equal strides
//! of 1 or 2 in both directions.
bool PoolingImpl::AlgoFilter4ModexStridexNCHW44::usable(
        const PoolingKernSizeParam& param) const {
    const auto src_enum = param.src_type.enumv();
    const bool src_is_qint8 = src_enum == DTypeEnum::QuantizedS8;
    const bool src_is_int8 = src_enum == DTypeEnum::Int8;
    if (!src_is_qint8 && !src_is_int8) {
        return false;
    }
    if (param.format != Param::Format::NCHW44) {
        return false;
    }

    const bool want_max = param.mode == Mode::MAX;
    const bool want_avg = param.mode == Mode::AVERAGE;
    if (!want_max && !want_avg) {
        return false;
    }
    //! Int8 does not support average, because its rounding mode differs from
    //! that of qint8
    if (src_is_int8 && want_avg) {
        return false;
    }

    const auto stride_h = param.stride[0];
    const auto stride_w = param.stride[1];
    const bool filter_ok = param.filter[0] == 4 && param.filter[1] == 4;
    const bool stride_ok =
            stride_h == stride_w && (stride_w == 1 || stride_w == 2);
    return filter_ok && stride_ok;
}

//! Runs 4x4 pooling over an NCHW44 int8 tensor.
//!
//! The stride (1 or 2) is checked first, then the mode (MAX or AVERAGE); the
//! pair selects one of the do_{max,avg}_pooling_4x4_stride{1,2}_int8_nchw44_NEON
//! kernels. Any other stride or mode is reported through megdnn_throw.
//! The stride is taken from param.stride[0] only.
void PoolingImpl::AlgoFilter4ModexStridexNCHW44::exec(
        const PoolingKernParam& param) const {
    auto N = param.n, C = param.ic;
    auto IH = param.isz[0], IW = param.isz[1];
    auto OH = param.osz[0], OW = param.osz[1];
    auto PH = param.padding[0], PW = param.padding[1];
    auto SW = param.stride[0];

    auto src_ptr = param.src_ptr;
    auto dst_ptr = param.dst_ptr;

//! Dispatches N * C tasks for one (stride, mode) kernel. A task maps its index
//! to (batch, chan) = (index / C, index % C), offsets src/dst by the matching
//! number of IH * IW * 4 (resp. OH * OW * 4) element planes, and binds the
//! workspace bundle to the slice selected by thread_id.
#define LAUNCH_KERN(elem_t, tag, step, pool)                                          \
    MIDOUT_BEGIN(                                                                     \
            megdnn_arm_common_pooling, midout_iv(10), midout_iv(#tag #step##_hash)) { \
        WorkspaceBundle wbundle = get_bundle_nchw44(param);                           \
        auto run = [C, IH, IW, OH, OW, PH, PW, src_ptr, dst_ptr, wbundle = wbundle,   \
                    workspace_ptr = param.workspace<dt_byte>()](                      \
                           size_t index, size_t thread_id) {                          \
            auto ws = wbundle;                                                        \
            ws.set(workspace_ptr + ws.total_size_in_bytes() * thread_id);             \
            const size_t batch = index / C;                                           \
            const size_t chan = index % C;                                            \
            const size_t src_off = batch * C * IH * IW * 4 + chan * IH * IW * 4;      \
            const size_t dst_off = batch * C * OH * OW * 4 + chan * OH * OW * 4;      \
            do_##pool##_pooling_4x4_stride##step##_##tag##_nchw44_NEON(               \
                    static_cast<const elem_t*>(src_ptr.get_ptr()) + src_off,          \
                    static_cast<elem_t*>(dst_ptr.get_ptr()) + dst_off, IH, IW, OH,    \
                    OW, PH, PW, ws);                                                  \
        };                                                                            \
        MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN(                                        \
                static_cast<::megdnn::naive::HandleImpl*>(param.handle), N* C, run);  \
    }                                                                                 \
    MIDOUT_END();

//! Picks the max or avg kernel for a fixed stride; other modes throw.
#define DISPATCH_BY_MODE(elem_t, tag, step)                                          \
    if (param.mode == Mode::MAX) {                                                   \
        LAUNCH_KERN(elem_t, tag, step, max);                                         \
    } else if (param.mode == Mode::AVERAGE) {                                        \
        LAUNCH_KERN(elem_t, tag, step, avg);                                         \
    } else {                                                                         \
        megdnn_throw(                                                                \
                ssprintf("Unsupport pooling mode %d", static_cast<int>(param.mode))  \
                        .c_str());                                                   \
    }

    if (SW == 1) {
        DISPATCH_BY_MODE(int8_t, int8, 1);
    } else if (SW == 2) {
        DISPATCH_BY_MODE(int8_t, int8, 2);
    } else {
        megdnn_throw(ssprintf("Unsupport stride size %d", SW).c_str());
    }

#undef DISPATCH_BY_MODE
#undef LAUNCH_KERN
}

//! Returns true when \p param describes a 5x5 NCHW44 pooling this algorithm
//! accepts: QuantizedS8 or Int8 source, MAX or AVERAGE mode, and equal strides
//! of 1 or 2 in both directions.
bool PoolingImpl::AlgoFilter5ModexStridexNCHW44::usable(
        const PoolingKernSizeParam& param) const {
    const auto src_enum = param.src_type.enumv();
    const bool src_is_qint8 = src_enum == DTypeEnum::QuantizedS8;
    const bool src_is_int8 = src_enum == DTypeEnum::Int8;
    if (!src_is_qint8 && !src_is_int8) {
        return false;
    }
    if (param.format != Param::Format::NCHW44) {
        return false;
    }

    const bool want_max = param.mode == Mode::MAX;
    const bool want_avg = param.mode == Mode::AVERAGE;
    if (!want_max && !want_avg) {
        return false;
    }
    //! Int8 does not support average, because its rounding mode differs from
    //! that of qint8
    if (src_is_int8 && want_avg) {
        return false;
    }

    const auto stride_h = param.stride[0];
    const auto stride_w = param.stride[1];
    const bool filter_ok = param.filter[0] == 5 && param.filter[1] == 5;
    const bool stride_ok =
            stride_h == stride_w && (stride_w == 1 || stride_w == 2);
    return filter_ok && stride_ok;
}

//! Runs 5x5 pooling over an NCHW44 int8 tensor.
//!
//! The stride (1 or 2) is checked first, then the mode (MAX or AVERAGE); the
//! pair selects one of the do_{max,avg}_pooling_5x5_stride{1,2}_int8_nchw44_NEON
//! kernels. Any other stride or mode is reported through megdnn_throw.
//! The stride is taken from param.stride[0] only.
void PoolingImpl::AlgoFilter5ModexStridexNCHW44::exec(
        const PoolingKernParam& param) const {
    auto N = param.n, C = param.ic;
    auto IH = param.isz[0], IW = param.isz[1];
    auto OH = param.osz[0], OW = param.osz[1];
    auto PH = param.padding[0], PW = param.padding[1];
    auto SW = param.stride[0];

    auto src_ptr = param.src_ptr;
    auto dst_ptr = param.dst_ptr;

//! Dispatches N * C tasks for one (stride, mode) kernel. A task maps its index
//! to (batch, chan) = (index / C, index % C), offsets src/dst by the matching
//! number of IH * IW * 4 (resp. OH * OW * 4) element planes, and binds the
//! workspace bundle to the slice selected by thread_id.
#define LAUNCH_KERN(elem_t, tag, step, pool)                                          \
    MIDOUT_BEGIN(                                                                     \
            megdnn_arm_common_pooling, midout_iv(11), midout_iv(#tag #step##_hash)) { \
        WorkspaceBundle wbundle = get_bundle_nchw44(param);                           \
        auto run = [C, IH, IW, OH, OW, PH, PW, src_ptr, dst_ptr, wbundle = wbundle,   \
                    workspace_ptr = param.workspace<dt_byte>()](                      \
                           size_t index, size_t thread_id) {                          \
            auto ws = wbundle;                                                        \
            ws.set(workspace_ptr + ws.total_size_in_bytes() * thread_id);             \
            const size_t batch = index / C;                                           \
            const size_t chan = index % C;                                            \
            const size_t src_off = batch * C * IH * IW * 4 + chan * IH * IW * 4;      \
            const size_t dst_off = batch * C * OH * OW * 4 + chan * OH * OW * 4;      \
            do_##pool##_pooling_5x5_stride##step##_##tag##_nchw44_NEON(               \
                    static_cast<const elem_t*>(src_ptr.get_ptr()) + src_off,          \
                    static_cast<elem_t*>(dst_ptr.get_ptr()) + dst_off, IH, IW, OH,    \
                    OW, PH, PW, ws);                                                  \
        };                                                                            \
        MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN(                                        \
                static_cast<::megdnn::naive::HandleImpl*>(param.handle), N* C, run);  \
    }                                                                                 \
    MIDOUT_END();

//! Picks the max or avg kernel for a fixed stride; other modes throw.
#define DISPATCH_BY_MODE(elem_t, tag, step)                                          \
    if (param.mode == Mode::MAX) {                                                   \
        LAUNCH_KERN(elem_t, tag, step, max);                                         \
    } else if (param.mode == Mode::AVERAGE) {                                        \
        LAUNCH_KERN(elem_t, tag, step, avg);                                         \
    } else {                                                                         \
        megdnn_throw(                                                                \
                ssprintf("Unsupport pooling mode %d", static_cast<int>(param.mode))  \
                        .c_str());                                                   \
    }

    if (SW == 1) {
        DISPATCH_BY_MODE(int8_t, int8, 1);
    } else if (SW == 2) {
        DISPATCH_BY_MODE(int8_t, int8, 2);
    } else {
        megdnn_throw(ssprintf("Unsupport stride size %d", SW).c_str());
    }

#undef DISPATCH_BY_MODE
#undef LAUNCH_KERN
}

}  // namespace arm_common
}  // namespace megdnn
// vim: syntax=cpp.doxygen
